from typing import Tuple
from collections import OrderedDict
import math
import functools

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import concurrent.futures
def merge(x, merge_token):
    """Average the T steps of ``x`` into ``merge_token`` contiguous segments.

    Segment boundaries are placed after the ``merge_token - 1`` positions
    whose adjacent steps are least similar. Similarity is the cosine
    similarity (over the P axis) of the features averaged over the C axis.

    Args:
        x: Floating-point tensor of shape ``(B, T, P, C)``.
        merge_token: Number of segments to produce. ``torch.topk`` is asked
            for ``merge_token - 1`` boundaries out of ``T - 1`` candidates.

    Returns:
        A tuple ``(output, scales, cum_indices)``:

        * ``output`` -- ``(B, merge_token, P, C)``, the mean of ``x`` over
          each segment.
        * ``scales`` -- ``(B, merge_token)``, number of steps in each
          segment, in ``x.dtype``.
        * ``cum_indices`` -- ``(B, merge_token)``, exclusive end position of
          each segment along T, in ``x.dtype``.
    """
    B, T, P, C = x.size()

    # Channel-major copy used for the actual averaging: (B, P*C, T).
    # permute + contiguous already materialises a copy, so no clone is needed.
    x_id = x.permute(0, 2, 3, 1).contiguous().view(B, P * C, T)

    # Per-step descriptor averaged over C: (B, P, T).
    frames = torch.mean(x, dim=-1).permute(0, 2, 1)

    # Cosine similarity between each step and its successor: (B, T-1).
    # Only the matching pairs are needed, so use an element-wise product
    # rather than a full (T-1) x (T-1) matmul followed by a diagonal.
    prev = F.normalize(frames[:, :, :-1], p=2, dim=-2)  # (B, P, T-1)
    after = F.normalize(frames[:, :, 1:], p=2, dim=-2)  # (B, P, T-1)
    sim = (prev * after).sum(dim=-2)

    # Longest segment possible when every other segment has length 1.
    kernel_size = T - merge_token + 1

    # Boundaries sit after the least similar adjacent pairs. All index
    # arithmetic stays in int64 so positions remain exact even when x is
    # half / bfloat16.
    _, cuts = torch.topk(sim, merge_token - 1, dim=-1, largest=False, sorted=False)
    cuts = torch.sort(cuts, dim=-1)[0] + 1
    first = cuts.new_zeros((B, 1))
    last = cuts.new_full((B, 1), T)
    edges = torch.cat((first, cuts, last), dim=1)  # (B, merge_token + 1)
    seg_end = edges[:, 1:]                         # exclusive segment ends
    seg_len = seg_end - edges[:, :-1]              # segment lengths

    # Sliding window i (from unfold below) starts at step i, so segment i
    # occupies positions [end - i - len, end - i) inside that window.
    offsets = torch.arange(merge_token, dtype=edges.dtype, device=edges.device)
    end_slice = (seg_end - offsets).unsqueeze(-1)   # (B, merge_token, 1)
    start_slice = end_slice - seg_len.unsqueeze(-1)  # (B, merge_token, 1)

    base_index = torch.arange(kernel_size, dtype=edges.dtype, device=edges.device)
    mask = (base_index >= start_slice) & (base_index < end_slice)  # (B, merge_token, kernel_size)

    scales = mask.sum(dim=-1, dtype=x.dtype)  # (B, merge_token)

    # avg_pool1d divides by kernel_size; pre-scaling by kernel_size / length
    # turns that into a mean over the segment's own steps.
    weights = mask.to(x.dtype) * (kernel_size / scales.unsqueeze(-1)).to(x.dtype)

    # Merge
    windows = x_id.unfold(dimension=-1, size=kernel_size, step=1)  # (B, P*C, merge_token, kernel_size)
    output = windows * weights.unsqueeze(1)
    output = output.reshape(B, P * C, merge_token * kernel_size)
    output = F.avg_pool1d(output, kernel_size=kernel_size, stride=kernel_size)

    output = output.view(B, P, C, merge_token).permute(0, 3, 1, 2)
    # Callers receive the segment ends in the activation dtype, as before.
    return output, scales, seg_end.to(x.dtype)


def unmerge(x, scale, cum_scale, T):
    """Expand ``S`` merged segments back to ``T`` steps by repetition.

    Output step ``t`` receives the features of every segment ``s`` with
    ``edge[s] <= t < edge[s + 1]``, where ``edge`` is ``cum_scale`` with a
    leading 0 prepended. Steps covered by no segment are zero.

    Args:
        x: Tensor of shape ``(B, S, P, C)``. It does not need to be
            contiguous.
        scale: Unused; kept so the signature stays unchanged.
        cum_scale: ``(B, S)`` exclusive end position of each segment.
        T: Number of steps in the output.

    Returns:
        Tensor of shape ``(B, T, P, C)``.
    """
    B, S, P, C = x.size()

    # Segment edges with a leading 0: (B, S + 1, 1).
    start = torch.zeros((cum_scale.size(0), 1), dtype=x.dtype, device=cum_scale.device)
    edges = torch.cat((start, cum_scale), dim=1).unsqueeze(-1)

    # mask[b, s, t] is True when step t lies in segment s.
    base_index = torch.arange(T, dtype=x.dtype, device=x.device)
    mask = (base_index >= edges[:, :-1]) & (base_index < edges[:, 1:])
    weights = mask.to(x.dtype)  # (B, S, T)

    # reshape (not view) so that permuted / non-contiguous inputs work too.
    x = x.reshape(B, S, P * C).unsqueeze(2)  # (B, S, 1, P*C)
    x = x * weights.unsqueeze(-1)            # (B, S, T, P*C)
    x = torch.sum(x, dim=1).view(B, T, P, C)

    return x